class CheckpointReasoner:
    """
    Stage 2: Metadata Reasoning and Filtering using Gemini.

    Sends the user's prompt (and the candidate checkpoints) to a Gemini model
    and interprets the JSON it replies with. Model replies are treated as
    untrusted text: they may be wrapped in Markdown fences, surrounded by
    prose, or not be a JSON object at all.
    """

    def __init__(
        self,
        model_name: str = "gemini-2.5-flash-exp",
        api_key: Optional[str] = None,
    ):
        """
        Initialize Gemini for reasoning.

        Args:
            model_name: Name of the Gemini model to instantiate.
            api_key: Optional API key. When given, it is passed to
                ``genai.configure``. When omitted, nothing is configured here,
                so a key configured earlier by the caller stays in effect.
                NOTE(review): the SDK is documented to fall back to the
                GOOGLE_API_KEY environment variable when nothing was
                configured - confirm for the installed version.
        """
        # A placeholder key must never be configured unconditionally: doing so
        # overwrites the real key a caller configured before constructing us.
        if api_key is not None:
            genai.configure(api_key=api_key)
        self.model = genai.GenerativeModel(model_name)

    @staticmethod
    def _parse_json_object(text: Optional[str]) -> Optional[Dict]:
        """Return the JSON object contained in *text*, or None if there is none."""
        cleaned = (text or "").strip()
        candidates = [cleaned]
        # Models often wrap JSON in ```json fences or add prose around it;
        # retry on the outermost {...} span when the raw text does not parse.
        start, end = cleaned.find("{"), cleaned.rfind("}")
        if 0 <= start < end:
            candidates.append(cleaned[start:end + 1])
        for candidate in candidates:
            try:
                parsed = json.loads(candidate)
            except json.JSONDecodeError:
                continue
            # Callers use dict methods on the result, so lists/scalars are rejected.
            if isinstance(parsed, dict):
                return parsed
        return None

    def extract_constraints(self, prompt: str) -> Dict:
        """
        Extract metadata constraints from natural language prompt.

        Args:
            prompt: The user's natural language request.

        Returns:
            The constraints as a dict, or an empty dict when the model's reply
            does not contain a JSON object.
        """
        system_prompt = """Extract metadata constraints from the user prompt for checkpoint selection.
        Return a JSON object with any of these fields if mentioned:
        - month: The month name if a temporal reference is made (e.g., "April")
        - version: Specific version number if mentioned (e.g., "v4", "version 2")
        - subject: The main subject being requested (e.g., "bear", "cat")
        - style: Visual style if specified (e.g., "photorealistic", "cartoon")
        - created_after: Date if "recent" or "latest" is mentioned
        - any other specific attributes mentioned
        
        Example input: "April-trained version of bear"
        Example output: {"month": "April", "subject": "bear"}
        
        Example input: "latest photorealistic cat model"
        Example output: {"subject": "cat", "style": "photorealistic", "created_after": "recent"}
        
        Only include fields that are explicitly mentioned or strongly implied.
        Return only valid JSON."""

        response = self.model.generate_content(
            f"{system_prompt}\n\nUser prompt: {prompt}"
        )

        constraints = self._parse_json_object(response.text)
        if constraints is None:
            print(f"Failed to parse constraints: {response.text}")
            return {}
        return constraints

    def filter_and_rank(
        self,
        checkpoints: List["Checkpoint"],
        prompt: str,
        constraints: Optional[Dict] = None
    ) -> Tuple[List["Checkpoint"], Optional[str]]:
        """
        Filter checkpoints based on constraints and generate clarification if needed.

        Args:
            checkpoints: Candidate checkpoints (top-K from semantic retrieval).
            prompt: The user's natural language request.
            constraints: Pre-extracted constraints; extracted from *prompt*
                via the model when None.

        Returns:
            - Filtered checkpoints, in the order the model listed their ids.
              Unknown ids are ignored and repeated ids are kept once.
            - Clarification question (if multiple matches), else None.

            When the model's reply cannot be parsed, the first three input
            checkpoints are returned unfiltered with no clarification.
        """
        # Nothing to filter: skip the model calls entirely.
        if not checkpoints:
            return [], None

        if constraints is None:
            constraints = self.extract_constraints(prompt)

        # Prepare checkpoint data for Gemini
        checkpoint_data = [
            {
                "id": cp.id,
                "version": cp.version,
                "description": cp.description,
                "subject_types": cp.subject_types,
                "metadata": cp.metadata,
                "similarity_score": getattr(cp, 'similarity_score', None)
            }
            for cp in checkpoints
        ]

        # default=str is deliberate: this JSON is only prompt text for the
        # model, so stringifying non-JSON values (dates, sets, paths) is
        # better than failing the whole request with a TypeError.
        filtering_prompt = f"""Given these top-K checkpoints retrieved by semantic similarity, 
        filter and rank them based on the user's constraints and intent.
        
        User prompt: {prompt}
        Extracted constraints: {json.dumps(constraints, default=str)}
        
        Checkpoints:
        {json.dumps(checkpoint_data, indent=2, default=str)}
        
        Instructions:
        1. Filter checkpoints that match the constraints
        2. If multiple checkpoints match, identify their key differences
        3. Generate a clarification question if needed
        
        Return JSON with:
        {{
            "selected_ids": ["checkpoint_id1", "checkpoint_id2", ...],
            "clarification": "question to ask user if multiple matches" or null,
            "reasoning": "brief explanation of filtering logic"
        }}
        """

        response = self.model.generate_content(filtering_prompt)

        result = self._parse_json_object(response.text)
        if result is None:
            print(f"Failed to parse filtering result: {response.text}")
            # Fallback: return top checkpoints without filtering
            return checkpoints[:3], None

        # The model may send null, or a bare id instead of a list.
        selected_ids = result.get("selected_ids") or []
        if not isinstance(selected_ids, list):
            selected_ids = [selected_ids]

        # Filter checkpoints based on selected IDs, preserving the model's order.
        selected_checkpoints = []
        for checkpoint_id in selected_ids:
            match = next((cp for cp in checkpoints if cp.id == checkpoint_id), None)
            if match is None:
                continue  # id not among the candidates we sent
            # A repeated id must not inflate the match count seen by callers.
            if not any(match is chosen for chosen in selected_checkpoints):
                selected_checkpoints.append(match)

        # Normalise an empty clarification to None.
        clarification = result.get("clarification") or None

        print(f"Reasoning: {result.get('reasoning', 'No reasoning provided')}")

        return selected_checkpoints, clarification


class CheckpointPipeline:
    """
    Complete pipeline combining all stages.

    Owns a description extractor, a reasoner (Stage 2) and a trigger-token
    mapper (Stage 3), and wires them together in ``process_query``.
    """

    def __init__(self, gemini_api_key: str):
        """
        Initialize all components.

        Args:
            gemini_api_key: API key passed to ``genai.configure``.
        """
        genai.configure(api_key=gemini_api_key)

        self.description_extractor = CheckpointDescriptionExtractor()
        self.reasoner = CheckpointReasoner()
        self.mapper = TriggerTokenMapper()

        # Re-apply the caller's key last: a component constructor that calls
        # genai.configure itself would otherwise leave its own key in effect.
        genai.configure(api_key=gemini_api_key)

    def process_query(
        self,
        prompt: str,
        top_k_checkpoints: List["Checkpoint"],
        auto_select: bool = True
    ) -> Tuple[Optional["Checkpoint"], Optional[str], Optional[str]]:
        """
        Process a query through Stage 2 and 3.

        Args:
            prompt: User's natural language prompt
            top_k_checkpoints: Results from Stage 1 (semantic retrieval)
            auto_select: If True, automatically select first checkpoint when
                multiple match and the reasoner produced a clarification
                question. If False in that situation, no checkpoint is chosen.

        Returns:
            - selected_checkpoint: Final selected checkpoint, or None when
              clarification is required (auto_select=False)
            - mapped_prompt: Prompt with trigger tokens, or None when
              clarification is required
            - clarification: Optional clarification question

        Raises:
            IndexError: If top_k_checkpoints is empty.
        """
        # Fail fast, before any model call. IndexError is kept because it is
        # what the unguarded top_k_checkpoints[0] fallback raised.
        if not top_k_checkpoints:
            raise IndexError(
                "top_k_checkpoints is empty: nothing to select a checkpoint from"
            )

        # Stage 2: Reasoning and Filtering
        filtered_checkpoints, clarification = self.reasoner.filter_and_rank(
            top_k_checkpoints,
            prompt
        )

        if not filtered_checkpoints:
            print("No checkpoints matched constraints, using best semantic match")
            selected_checkpoint = top_k_checkpoints[0]
        else:
            needs_clarification = bool(clarification) and len(filtered_checkpoints) > 1
            if needs_clarification and not auto_select:
                # In production, you would return to user for clarification
                return None, None, clarification
            if needs_clarification:
                print("Multiple matches found. Auto-selecting first one.")
                print(f"(Would ask user: {clarification})")
            # Single match, auto-selected match, or multiple matches without a
            # clarification question: take the reasoner's first choice.
            selected_checkpoint = filtered_checkpoints[0]

        # Stage 3: Token Mapping
        mapped_prompt = self.mapper.map_prompt_tokens(prompt, selected_checkpoint)

        return selected_checkpoint, mapped_prompt, clarification


# Example usage
if __name__ == "__main__":
    import os

    # One prompt drives retrieval, selection and the final report.
    example_prompt = "Forest scene with April-trained photorealistic bear"

    # Extract and store a description for a newly added checkpoint.
    extractor = CheckpointDescriptionExtractor()
    description = extractor.process_new_checkpoint(
        checkpoint_path="./models/bear_v4.safetensors",
        subject_type="bear",
        output_metadata_path="./data/checkpoints/bear/v4/metadata.json"
    )
    print(f"Extracted description: {description}")

    # A plain `import retrieve_top_k_checkpoints` binds the module, and calling
    # a module raises TypeError, so the callable is imported instead.
    # NOTE(review): assumes the module exposes a function of the same name -
    # confirm against the Stage 1 retrieval code.
    from retrieve_top_k_checkpoints import retrieve_top_k_checkpoints

    # Stage 1: semantic retrieval of candidate checkpoints.
    top_k = retrieve_top_k_checkpoints(
        example_prompt,
        top_k=10
    )

    # Prefer a key from the environment; the placeholder stays as the fallback.
    pipeline = CheckpointPipeline(
        gemini_api_key=os.environ.get("GEMINI_API_KEY", "YOUR_KEY")
    )
    checkpoint, mapped_prompt, clarification = pipeline.process_query(
        prompt=example_prompt,
        top_k_checkpoints=top_k,
        auto_select=True
    )

    print(f"\nSelected checkpoint: {checkpoint.id}")
    print(f"Original prompt: {example_prompt}")
    print(f"Mapped prompt: {mapped_prompt}")
    if clarification:
        print(f"Clarification needed: {clarification}")